library(tidyverse)
library(tidymodels)

train_model <- function(data, model_spec, outcome) {
  # Build a workflow from `model_spec` plus the recipe returned by
  # make_matrix_recipe(), fit it on `data`, and return the fitted workflow.
  # `outcome` is a bare column name forwarded with {{ }}.
  matrix_recipe <- make_matrix_recipe(data, {{ outcome }})
  untrained_workflow <- add_recipe(add_model(workflow(), model_spec), matrix_recipe)
  fit(untrained_workflow, data)
}

get_xgboost_spec <- function() {
  # Base xgboost classification spec: tree booster with a binary logistic
  # objective. No main hyperparameters are set here; callers add them.
  spec <- boost_tree(mode = "classification")
  set_engine(spec, "xgboost", booster = "gbtree", objective = "binary:logistic")
}

get_model_preds <- function(model, data, true_outcome) {
  # Class-probability predictions of `model` on `data`, column-bound with the
  # identifying columns (`id_var`, `S_index_date_XX`, every `EVE*` column) and
  # the observed outcome renamed to `true_value`.
  # `true_outcome` is a bare column name forwarded with {{ }}.
  prob_preds <- predict(model, data, type = "prob")
  id_and_truth <- select(data,
                         id_var,
                         S_index_date_XX,
                         starts_with("EVE"),
                         true_value = {{ true_outcome }})
  bind_cols(prob_preds, id_and_truth)
}

get_feature_importance <- function(model, proportion, stratify_by) {
  # Feature importance of the xgboost booster inside a fitted workflow, as
  # reported by xgboost::xgb.importance().
  # NOTE(review): `proportion` and `stratify_by` are never used; they seem to
  # be kept only so existing calls keep working.
  parsnip_fit <- pull_workflow_fit(model)
  booster <- parsnip_fit[["fit"]]
  xgboost::xgb.importance(model = booster)
}

create_hashed_split <- function(data, proportion) {
  # Deterministically split `data` into "train" and "test" subsets by hashing
  # `id_var`, so all rows sharing an id land in the same subset on every run.
  #
  # data:       data frame with an `id_var` column.
  # proportion: single number in (0, 1]. A row goes to "train" when its hashed
  #             id is divisible by as.integer(1 / proportion), so "train" gets
  #             roughly 1 / as.integer(1 / proportion) of the ids (0.2 -> every
  #             5th hash -> ~20% train). Any value above 0.5 gives a modulus of
  #             1, i.e. every row in "train".
  #             NOTE(review): confirm callers mean `proportion` as the TRAIN
  #             share; hashed_split_by_proportion() uses it the same way.
  #
  # Returns a list of tibbles named by subset ("test", "train"); a subset with
  # no rows is absent. Each tibble gains `hashed_id_num` and `data_set` columns.
  if (!is.numeric(proportion) || length(proportion) != 1L ||
      is.na(proportion) || proportion <= 0 || proportion > 1) {
    # Without this check a zero/negative/oversized proportion produced an
    # infinite or zero modulus and silently NA subset labels.
    stop("`proportion` must be a single number in (0, 1].", call. = FALSE)
  }
  modulo_type <- as.integer(1 / proportion)

  # NOTE(review): nothing below is random, so this seed does not affect the
  # result; it is kept only because it resets the global RNG for callers.
  set.seed(3)

  data <- data %>%
    mutate(
      # Keep the decimal digits of the md5 hex digest and take the first four,
      # giving a number in 0-9999 that is stable for a given id.
      hashed_id_num = gsub("[^0-9]", "", openssl::md5(as.character(id_var))) %>%
        substr(1, 4) %>%
        as.numeric(),
      data_set = if_else(hashed_id_num %% modulo_type != 0, "test", "train")
    ) %>%
    group_by(data_set)

  # One tibble per `data_set` value, named by that value.
  data %>% group_split() %>% set_names(group_keys(data)[[1]])
}

downsample_data <- function(data, outcome) {
  # Undersample the negative class: keep every row whose outcome is "1", and
  # keep other rows only when `hashed_id_num` is at or below
  # (positive rate / negative rate) * 10000. With hashes spread over 0-9999,
  # that keeps roughly as many negatives as there are positives.
  # `outcome` is a bare column name forwarded with {{ }}.
  outcome_values <- pull(data, {{ outcome }})
  positive_rate <- sum(outcome_values == "1") / nrow(data)
  hash_cutoff <- (positive_rate / (1 - positive_rate)) * 10000
  filter(data, {{ outcome }} == "1" | hashed_id_num <= hash_cutoff)
}

get_calibration_proportion <- function(data, outcome) {
  # Ratio of rows whose outcome is "0" to rows whose outcome is "1".
  # `outcome` is a bare column name forwarded with {{ }}.
  n_negative <- nrow(filter(data, {{ outcome }} == "0"))
  n_positive <- nrow(filter(data, {{ outcome }} == "1"))
  n_negative / n_positive
}

calibrate_preds <- function(preds_data, calibration_proportion, preds_col) {
  # Adds `preds_after_bayes` = p / (c - (c - 1) * p), where p is `preds_col`
  # and c is `calibration_proportion`, mapping scores from an undersampled
  # model back towards the original class balance.
  # `preds_col` is a bare column name forwarded with {{ }}.
  mutate(
    preds_data,
    preds_after_bayes = {{ preds_col }} /
      (calibration_proportion - (calibration_proportion - 1) * {{ preds_col }})
  )
}

get_xgboost_parameters_grid <- function(df, outcome, grid_size) {
  # tibble, varname, int -> tibble
  # Helper for grid search: draws `grid_size` random xgboost hyperparameter
  # combinations. Columns are named after the boost_tree() arguments
  # (trees, tree_depth, learn_rate, loss_reduction, mtry, min_n, sample_size),
  # so each row can be spliced straight into set_args().
  #
  # df:        raw data; it is run through make_matrix_recipe() so mtry's upper
  #            bound reflects the processed predictor count.
  # outcome:   bare outcome column name (forwarded with {{ }}).
  # grid_size: number of combinations to draw.

  # Processed predictors only: mtry is a count of predictors, so the outcome
  # column must not be counted when finalizing its range.
  predictors <- make_matrix_recipe(df, {{ outcome }}) %>%
    prep() %>%
    juice(all_predictors())
  print("preparing parameters grid...")
  grid_random(
    trees(c(300L, 400L)),
    tree_depth(c(1, 10)),
    # learn_rate() defaults to a log10 transform; trans = NULL makes the range
    # literal (0.01-0.3) instead of 10^0.01-10^0.3 (about 1.02-2.0).
    learn_rate(c(0.01, 0.3), trans = NULL),
    loss_reduction(c(0, 20), trans = NULL),
    mtry() %>% finalize(predictors),
    min_n(c(0, 20)),
    # xgboost subsamples a proportion of rows; sample_size() is an integer row
    # count, while sample_prop() is the proportion version (its grid column is
    # still named `sample_size`, matching boost_tree()).
    sample_prop(c(0.5, 0.8)),
    size = grid_size
  )
}

get_specs_from_grid <- function(parameters_grid, base_model_spec) {
  # tibble, spec -> list(spec)
  # Helper for grid search: returns one model spec per grid row, namely
  # `base_model_spec` with that row's values injected via set_args().
  print("preparing xgboost specs from grid...")
  row_ids <- seq_len(nrow(parameters_grid))
  map(row_ids, function(i) {
    # Named list of this row's values; `!!!` injects them as literal values.
    row_values <- as.list(parameters_grid[i, , drop = FALSE])
    set_args(base_model_spec, !!!row_values)
  })
}

run_grid_search <- function(specs, df, outcome) {
  # list(spec), tibble, varname -> tibble
  # Helper for grid search: fits every spec in `specs` on one 80/20 split of
  # `df` and scores it by ROC AUC on the held-out part.
  # Returns a tibble with `roc_auc` (one value per spec) and `model_specs`.

  train_eval_helper <- function(model_spec, n, train, test) {
    # Fit `model_spec` on `train`, predict `test`, return ROC AUC for class "1".
    print(paste0(Sys.time(), ": Start training model number ", n))
    # Roles come from `outcome` instead of a hard-coded formula, so the search
    # works for any outcome column.
    recipe <- recipes::recipe(train) %>%
      update_role(everything(), new_role = "predictor") %>%
      update_role({{ outcome }}, new_role = "outcome")
    preds <- workflow() %>%
      add_model(model_spec) %>%
      add_recipe(recipe) %>%
      fit(train) %>%
      predict(test, type = "prob") %>%
      bind_cols(test %>% select(true_value = {{ outcome }}))
    # `.pred_1` is the probability of class "1", so "1" must be yardstick's
    # event level whatever the factor-level order is (the default "first"
    # level would be "0" for levels c("0", "1"), reversing the AUC).
    event_level <- if (identical(levels(preds$true_value)[1], "1")) {
      "first"
    } else {
      "second"
    }
    auc <- preds %>%
      roc_auc(truth = true_value, .pred_1, event_level = event_level) %>%
      pull(.estimate)
    print(auc)
    auc
  }

  print("Preparing data for the grid search..")
  # NOTE(review): stratifying on an id column looks unintended; if ids repeat
  # across rows and must not straddle the split, a grouped split is needed.
  train_val_split <- initial_split(df, 0.8, strata = id_var)
  recipe <- make_matrix_recipe(training(train_val_split), {{ outcome }}) %>% prep()
  train <- juice(recipe)
  test <- recipe %>% bake(testing(train_val_split))
  print("Starting grid search...")
  roc_auc_list <- imap_dbl(specs, train_eval_helper, train = train, test = test)
  tibble(roc_auc = roc_auc_list,
         model_specs = specs)
}


find_best_parameters <- function(df, outcome, base_model_spec, grid_size) {
  # tibble, varname, spec, int -> tibble
  # Random grid search over xgboost hyperparameters:
  #   1. draw `grid_size` combinations;
  #   2. turn each combination into a spec based on `base_model_spec`;
  #   3. score the specs with run_grid_search();
  #   4. return the best-scoring row(s), several if ROC AUCs tie.
  #
  # The steps run one after another rather than through one pipe. A lazily
  # evaluated pipe (magrittr >= 2.0) starts the last step first, so the
  # progress messages appeared in reverse order.
  parameters_grid <- get_xgboost_parameters_grid(df, {{ outcome }}, grid_size)
  specs <- get_specs_from_grid(parameters_grid, base_model_spec)
  results <- run_grid_search(specs, df, {{ outcome }})
  # na.rm: one spec whose AUC could not be computed must not empty the result.
  results %>% filter(roc_auc == max(roc_auc, na.rm = TRUE))
}

hashed_split_by_proportion <- function(data, proportion) {
  # Randomly split `data` into an rsample split by assigning whole
  # `hashed_id_num` values to one side, so rows sharing a hash never straddle
  # analysis and assessment.
  #
  # data:       data frame with a `hashed_id_num` column.
  # proportion: share of the distinct `hashed_id_num` values sent to analysis.
  #
  # Returns an rsplit from rsample::make_splits(). Draws from the global RNG.

  # Absolute row positions of rows whose hash is in `subset`. which() ignores
  # grouping, whereas row_number() restarts inside each group of a grouped
  # input and would produce wrong split indices.
  get_row_nums <- function(subset) {
    which(data$hashed_id_num %in% subset)
  }

  unique_hashed_ids <- unique(data$hashed_id_num)
  split_len <- length(unique_hashed_ids) * proportion
  # Index-based sampling: sample(x, k) treats a single number x >= 1 as 1:x,
  # which would return arbitrary integers when only one hash is present. For
  # longer inputs this makes exactly the same draws as sample(x, k).
  train_ids <- unique_hashed_ids[sample.int(length(unique_hashed_ids), split_len)]
  test_ids <- setdiff(unique_hashed_ids, train_ids)
  train_rows <- get_row_nums(train_ids)
  test_rows <- get_row_nums(test_ids)
  rsample::make_splits(list(analysis = train_rows, assessment = test_rows), data)
}


run_bayes_optimisation <- function(data, outcome, initial, n_iters) {
  # Bayesian hyperparameter search for an xgboost classifier with
  # tune::tune_bayes(), scored by ROC AUC on one hashed 80/20 split of `data`.
  #
  # data:    data frame containing `hashed_id_num` (used for the split) and the
  #          outcome column.
  # outcome: bare outcome column name (forwarded with {{ }}).
  # initial: a number, passed to tune_bayes() as the count of initial results.
  #          Anything else is treated as parameter combinations: it is appended
  #          to a Latin hypercube grid and evaluated with tune_grid() to seed
  #          the Bayesian search.
  # n_iters: maximum number of tune_bayes() iterations.
  #
  # Returns the result of tune_bayes().
  set.seed(20)
  # Wrap the single analysis/assessment split as an rset so tune accepts it.
  # The split goes in only through `splits`. Piping it into
  # `new_rset(splits = list(.), ...)` made magrittr also insert it as the first
  # positional argument, which R matched to the remaining `attrib` formal.
  hashed_split <- hashed_split_by_proportion(data, 0.8)
  test_train_split <- rsample::new_rset(splits = list(hashed_split),
                                        ids = "sample",
                                        subclass = "rset")
  recipe <- make_matrix_recipe(data, {{ outcome }})

  # Same base spec as get_xgboost_spec() (gbtree, binary:logistic), with every
  # tunable main argument marked for tuning.
  model_spec <- get_xgboost_spec() %>%
    set_args(nthread = 40,
             trees = tune(),
             tree_depth = tune(),
             learn_rate = tune(),
             loss_reduction = tune(),
             min_n = tune(),
             sample_size = tune(),
             mtry = tune(),
             silent = 1)

  xgb_workflow <- workflow() %>%
    add_model(model_spec) %>%
    add_recipe(recipe)

  params <- parameters(xgb_workflow) %>%
    update(tree_depth = tree_depth(c(1, 10)),
           trees = trees(c(200L, 400L)),
           learn_rate = learn_rate(c(0.01, 0.3), trans = NULL),
           loss_reduction = loss_reduction(c(0, 20), trans = NULL),
           # NOTE(review): finalized on the raw `data`, so mtry's upper bound
           # counts every raw column rather than the recipe's predictors;
           # confirm this is intended.
           mtry = mtry() %>% finalize(data),
           sample_size = sample_prop(c(0.7, 1)),
           min_n = min_n(c(0, 20)) %>% finalize(data))

  if (!is.numeric(initial)) {
    print("Prepare initial grid with given set of the parameters")
    # NOTE(review): length() of a data frame is its number of columns, so the
    # hypercube gets one point per parameter column, not per supplied row;
    # confirm this is intended.
    initial <- params %>%
      dials::grid_latin_hypercube(size = length(initial)) %>%
      bind_rows(initial) %>%
      tune_grid(xgb_workflow, resamples = test_train_split, grid = .)
  }

  xgb_workflow %>%
    tune_bayes(resamples = test_train_split,
               param_info = params,
               initial = initial,
               iter = n_iters,
               metrics = metric_set(roc_auc),
               control = control_bayes(no_improve = 30, verbose = TRUE))
}

